﻿#pragma kernel ApplyInertialForces
#pragma kernel ApplyRigidbodyDeltas
#pragma kernel PredictPositions
#pragma kernel UpdateVelocities
#pragma kernel UpdatePositions
#pragma kernel UpdateLifetimes
#pragma kernel EnforceLimits
#pragma kernel Interpolate

#include "Bounds.cginc"
#include "Integration.cginc"
#include "CollisionMaterial.cginc"
#include "SolverParameters.cginc"
#include "MathUtils.cginc"
#include "Rigidbody.cginc"

// ---- Read-only per-particle inputs ----
// NOTE(review): simplices and fluidRadii are not referenced by any kernel visible in this file.
StructuredBuffer<int> simplices;
// Maps a dispatch thread index to the particle index that most kernels operate on.
StructuredBuffer<int> activeParticles;
// Kernels only integrate/differentiate linear motion when the value is > 0.
StructuredBuffer<float> invMasses;
// Kernels only integrate/differentiate angular motion when the value is > 0.
StructuredBuffer<float> invRotationalMasses;
// Bit flags; tested against PHASE_FLUID (PredictPositions) and PHASE_ISOLATED (EnforceLimits).
StructuredBuffer<int> phases;
// The z component scales (negated) gravity for fluid-phase particles in PredictPositions.
StructuredBuffer<float4> buoyancies;
StructuredBuffer<float> fluidRadii;

// ---- Inputs blended by the Interpolate kernel ----
StructuredBuffer<float4> startPositions;
StructuredBuffer<float4> endPositions;
StructuredBuffer<quaternion> startOrientations;
StructuredBuffer<quaternion> endOrientations;

// ---- Read/write particle state ----
RWStructuredBuffer<float4> positions;
RWStructuredBuffer<quaternion> orientations;
RWStructuredBuffer<float4> principalRadii;
// renderable* buffers are written only by the Interpolate kernel.
RWStructuredBuffer<float4> renderablePositions;
RWStructuredBuffer<quaternion> renderableOrientations;
RWStructuredBuffer<float4> renderableRadii;
// Snapshot of positions/orientations taken at the start of PredictPositions.
RWStructuredBuffer<float4> prevPositions;
RWStructuredBuffer<quaternion> prevOrientations;
RWStructuredBuffer<float4> velocities;
RWStructuredBuffer<float4> angularVelocities;
// Decremented by deltaTime in UpdateLifetimes; set to 0 when a particle is killed.
RWStructuredBuffer<float> life;
// Written by ApplyInertialForces (ambient wind, optionally minus the frame's velocity).
RWStructuredBuffer<float4> wind;

// Indices of expired particles; UpdateLifetimes appends to it through IncrementCounter(),
// so the buffer presumably needs to be created with a hidden counter - confirm on the CPU side.
RWStructuredBuffer<int> deadParticles;

// External forces/torques applied to velocities in PredictPositions.
RWStructuredBuffer<float4> externalForces;
RWStructuredBuffer<float4> externalTorques;

// Float deltas written by ApplyRigidbodyDeltas from their integer-encoded counterparts.
RWStructuredBuffer<float4> linearDeltas;
RWStructuredBuffer<float4> angularDeltas;

// Only element 0 is read (ApplyInertialForces, when inertialWind is enabled).
StructuredBuffer<inertialFrame> inertialSolverFrame;

// Variables set from the CPU
// Upper bound on the dispatch thread index; every kernel early-outs past it.
uint particleCount;

// Time step used by the integration kernels.
float deltaTime;
// Interpolation factor used by the Interpolate kernel.
float blendFactor;
// Multiplier applied to linear and angular velocities in UpdatePositions.
float velocityScale;
// When true, isolated particles found outside the boundary limits get their life set to 0.
bool killOffLimits;

// Angular velocity, linear acceleration and angular acceleration used to build
// the fictitious forces in ApplyInertialForces.
float4 angularVel;
float4 inertialAccel;
float4 eulerAccel;
// Base wind value written for every active particle.
float4 ambientWind;
// When true, the frame's velocity at the particle is subtracted from the wind.
bool inertialWind;

// Axis-aligned bounds (xyz) that EnforceLimits clamps positions to.
float4 boundaryLimitsMin;
float4 boundaryLimitsMax;

// Applies the fictitious accelerations of the solver's moving frame (linear,
// Euler, Coriolis and centrifugal) to each active particle's velocity, and
// initializes its per-particle wind value.
// One thread per entry of activeParticles.
[numthreads(128, 1, 1)]
void ApplyInertialForces(uint3 id : SV_DispatchThreadID)
{
    const uint threadIndex = id.x;
    if (threadIndex >= particleCount)
        return;

    const int particle = activeParticles[threadIndex];

    // Only particles with a positive inverse mass are accelerated.
    if (invMasses[particle] > 0)
    {
        const float3 r = positions[particle].xyz;
        const float3 v = velocities[particle].xyz;
        const float3 omega = angularVel.xyz;

        float4 eulerTerm = float4(cross(eulerAccel.xyz, r), 0);
        float4 centrifugalTerm = float4(cross(omega, cross(omega, r)), 0);
        float4 coriolisTerm = 2 * float4(cross(omega, v), 0);
        float4 rotationalAccel = eulerTerm + coriolisTerm + centrifugalTerm;

        // Linear and rotational contributions are scaled independently.
        velocities[particle] -= (inertialAccel * worldLinearInertiaScale + rotationalAccel * worldAngularInertiaScale) * deltaTime;
    }

    // Start from the ambient wind; optionally subtract the frame's own velocity
    // at the particle, expressed back in the frame's local space.
    float4 particleWind = ambientWind;

    if (inertialWind)
    {
        float4 worldPos = inertialSolverFrame[0].frame.TransformPoint(positions[particle]);
        particleWind -= inertialSolverFrame[0].frame.InverseTransformVector(inertialSolverFrame[0].velocityAtPoint(worldPos));
    }

    wind[particle] = particleWind;
}

// Reinterprets the bit patterns stored in linearDeltasAsInt / angularDeltasAsInt
// as floats and stores them in the xyz components of linearDeltas / angularDeltas.
// The w component of the destination entries is left untouched.
// One thread per entry; particleCount is used as the entry count here.
[numthreads(128, 1, 1)]
void ApplyRigidbodyDeltas (uint3 id : SV_DispatchThreadID)
{
    const uint entry = id.x;
    if (entry >= particleCount)
        return;

    // Linear delta: bit-cast each integer component back to float.
    const float linX = asfloat(linearDeltasAsInt[entry].x);
    const float linY = asfloat(linearDeltasAsInt[entry].y);
    const float linZ = asfloat(linearDeltasAsInt[entry].z);
    linearDeltas[entry].xyz = float3(linX, linY, linZ);

    // Angular delta: same treatment.
    const float angX = asfloat(angularDeltasAsInt[entry].x);
    const float angY = asfloat(angularDeltasAsInt[entry].y);
    const float angZ = asfloat(angularDeltasAsInt[entry].z);
    angularDeltas[entry].xyz = float3(angX, angY, angZ);
}

// Stores the current position/orientation of each active particle as its
// "previous" state, applies external forces, torques and gravity to its
// velocities, and then integrates position and orientation over deltaTime.
// One thread per entry of activeParticles.
[numthreads(128, 1, 1)]
void PredictPositions (uint3 id : SV_DispatchThreadID)
{
    unsigned int i = id.x;
    if (i >= particleCount) return;

    int p = activeParticles[i];

    // the previous position/orientation is the current position/orientation at the start of the step.
    prevPositions[p] = positions[p];
    prevOrientations[p] = orientations[p];

    // predict positions (only for particles with a positive inverse mass):
    if (invMasses[p] > 0)
    {
        float4 effectiveGravity = float4(gravity,0);

        // Adjust gravity for buoyant fluid particles:
        if ((phases[p] & (int)PHASE_FLUID) != 0)
            effectiveGravity *= -buoyancies[p].z;

        // apply external forces and gravity:
        float4 vel = velocities[p] + (invMasses[p] * externalForces[p] + effectiveGravity) * deltaTime;

        // project velocity to the 2D (XY) plane if needed: the out-of-plane
        // component is z, matching the z-axis projection used for the angular
        // velocity below and the z restriction applied in UpdateVelocities.
        if (mode == 1)
            vel.z = 0;

        velocities[p] = vel;
    }

    if (invRotationalMasses[p] > 0)
    {
        // apply external torques (simplification: we don't use full inertia tensor here)
        // (named newAngularVel to avoid shadowing the global 'angularVel' uniform)
        float3 newAngularVel = angularVelocities[p].xyz + invRotationalMasses[p] * externalTorques[p].xyz * deltaTime;

        // project angular velocity to 2D plane normal if needed:
        if (mode == 1)
            newAngularVel = Project(newAngularVel, float3(0, 0, 1));

        // the w component of the angular velocity is preserved as-is.
        angularVelocities[p] = float4(newAngularVel, angularVelocities[p].w);
    }

    // integrate using the (possibly updated) velocities.
    positions[p] = IntegrateLinear(positions[p], velocities[p], deltaTime);
    orientations[p] = IntegrateAngular(orientations[p], angularVelocities[p], deltaTime);
}

// Derives linear and angular velocities of each active particle from the
// difference between its current and previous position/orientation.
// Particles whose inverse (rotational) mass is not positive get zero velocity.
// One thread per entry of activeParticles.
[numthreads(128, 1, 1)]
void UpdateVelocities (uint3 id : SV_DispatchThreadID)
{
    const uint threadIndex = id.x;
    if (threadIndex >= particleCount)
        return;

    const int particle = activeParticles[threadIndex];

    // In 2D mode, keep the particle on its plane by restoring the previous z.
    if (mode == 1)
    {
        positions[particle].z = prevPositions[particle].z;
    }

    // Linear velocity.
    if (!(invMasses[particle] > 0))
    {
        velocities[particle] = FLOAT4_ZERO;
    }
    else
    {
        velocities[particle] = DifferentiateLinear(positions[particle], prevPositions[particle], deltaTime);
    }

    // Angular velocity: only xyz is overwritten for rotating particles,
    // while the whole vector is reset for non-rotating ones.
    if (!(invRotationalMasses[particle] > 0))
    {
        angularVelocities[particle] = FLOAT4_ZERO;
    }
    else
    {
        angularVelocities[particle].xyz = DifferentiateAngular(orientations[particle], prevOrientations[particle], deltaTime).xyz;
    }
}

// Damps and clamps the velocities of each active particle, and puts particles
// whose kinetic energy falls below sleepThreshold to sleep by restoring their
// previous position/orientation and zeroing their velocities.
// One thread per entry of activeParticles.
[numthreads(128, 1, 1)]
void UpdatePositions (uint3 id : SV_DispatchThreadID)
{
    const uint threadIndex = id.x;
    if (threadIndex >= particleCount)
        return;

    const int particle = activeParticles[threadIndex];

    // Damping: scale both velocities by the same factor.
    float4 linearVel = velocities[particle] * velocityScale;
    float3 angularVelXYZ = angularVelocities[particle].xyz * velocityScale;

    // Magnitudes are measured after damping but before clamping; the sleep
    // test below uses these same (pre-clamp) values.
    const float linearSpeed = length(linearVel);
    const float angularSpeed = length(angularVelXYZ);

    // Clamping: rescale so the magnitude does not exceed the configured maximum.
    if (linearSpeed > EPSILON)
        linearVel *= min(maxVelocity, linearSpeed) / linearSpeed;

    if (angularSpeed > EPSILON)
        angularVelXYZ *= min(maxAngularVelocity, angularSpeed) / angularSpeed;

    // Sleeping: below the energy threshold, freeze the particle in place.
    const float kineticEnergy = linearSpeed * linearSpeed * 0.5f + angularSpeed * angularSpeed * 0.5f;
    if (kineticEnergy <= sleepThreshold)
    {
        positions[particle] = prevPositions[particle];
        orientations[particle] = prevOrientations[particle];
        linearVel = FLOAT4_ZERO;
        angularVelXYZ = float3(0,0,0);
    }

    velocities[particle] = linearVel;
    angularVelocities[particle].xyz = angularVelXYZ;
}

// Decreases the remaining life of each active particle by deltaTime. Particles
// whose life reaches zero (or below) are appended to deadParticles and have
// their life clamped to 0.
// One thread per entry of activeParticles.
[numthreads(128, 1, 1)]
void UpdateLifetimes (uint3 id : SV_DispatchThreadID)
{
    const uint threadIndex = id.x;
    if (threadIndex >= particleCount)
        return;

    const int particle = activeParticles[threadIndex];

    float remainingLife = life[particle] - deltaTime;

    if (remainingLife <= 0)
    {
        // The buffer's counter hands out a unique slot per dead particle.
        const uint slot = deadParticles.IncrementCounter();
        deadParticles[slot] = particle;

        remainingLife = 0;
    }

    life[particle] = remainingLife;
}

// Clamps the current and previous positions of each active particle to the
// box [boundaryLimitsMin, boundaryLimitsMax] (xyz only; w is untouched).
// Isolated particles found on or outside the box are killed (life set to 0)
// when killOffLimits is enabled.
// One thread per entry of activeParticles.
[numthreads(128, 1, 1)]
void EnforceLimits (uint3 id : SV_DispatchThreadID)
{
    const uint threadIndex = id.x;
    if (threadIndex >= particleCount)
        return;

    const int particle = activeParticles[threadIndex];

    float4 current = positions[particle];
    float4 previous = prevPositions[particle];

    // step(a, b) is 1 where b >= a: flags components at/below the minimum
    // and at/above the maximum. Any flagged component means "outside".
    const float3 atOrBelowMin = step(current, boundaryLimitsMin).xyz;
    const float3 atOrAboveMax = step(boundaryLimitsMax, current).xyz;
    const bool outside = any(atOrBelowMin + atOrAboveMax);

    if ((phases[particle] & (int)PHASE_ISOLATED) != 0)
    {
        if (killOffLimits && outside)
            life[particle] = 0;
    }

    current.xyz = clamp(current.xyz, boundaryLimitsMin.xyz, boundaryLimitsMax.xyz);
    previous.xyz = clamp(previous.xyz, boundaryLimitsMin.xyz, boundaryLimitsMax.xyz);

    positions[particle] = current;
    prevPositions[particle] = previous;
}

// Fills the renderable position/orientation/radii buffers for each particle.
//   interpolation == 1: blend between the start and end states.
//   interpolation == 2: blend between the end state and the current state.
//   otherwise:          use the end state directly.
// Orientations are always normalized; radii are copied from principalRadii.
// Indexed directly by thread id (activeParticles is not used here).
[numthreads(128, 1, 1)]
void Interpolate (uint3 id : SV_DispatchThreadID)
{
    const uint particle = id.x;
    if (particle >= particleCount)
        return;

    float4 blendedPosition;
    quaternion blendedOrientation;

    if (interpolation == 1)
    {
        blendedPosition = lerp(startPositions[particle], endPositions[particle], blendFactor);
        blendedOrientation = q_slerp(startOrientations[particle], endOrientations[particle], blendFactor);
    }
    else if (interpolation == 2)
    {
        blendedPosition = lerp(endPositions[particle], positions[particle], blendFactor);
        blendedOrientation = q_slerp(endOrientations[particle], orientations[particle], blendFactor);
    }
    else
    {
        blendedPosition = endPositions[particle];
        blendedOrientation = endOrientations[particle];
    }

    renderablePositions[particle] = blendedPosition;
    renderableOrientations[particle] = normalize(blendedOrientation);
    renderableRadii[particle] = principalRadii[particle];
}